import visual_behavior_glm
import visual_behavior_glm.src.GLM_params as glm_params
import visual_behavior_glm.src.GLM_analysis_tools as gat
from visual_behavior_glm.src.glm import GLM
import matplotlib.pyplot as plt
import visual_behavior.data_access.loading as loading
import visual_behavior.database as db
import plotly.express as px
import pandas as pd
import numpy as np
import os
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import seaborn as sns
%matplotlib inline
# Pull all GLM fit results from the project database to inspect the available model versions.
results_all = gat.retrieve_results(results_type='full')
results_all['glm_version'].unique()
#use v4
# Summary results restricted to a single model version.
rs = gat.retrieve_results(search_dict = {'glm_version': '4_L2_optimize_by_cell'}, results_type='summary')
len(rs)
# One unique key per cell per experiment: "<ophys_experiment_id>_<cell_specimen_id>".
rs['identifier'] = rs['ophys_experiment_id'].astype(str) + '_' + rs['cell_specimen_id'].astype(str)
rs
# Pivot to a cells x dropouts matrix of variance explained.
model_output_type = 'variance_explained'
ve = rs.pivot(index='identifier',columns='dropout',values=model_output_type).reset_index()
ve
# Keep only cells whose full-model variance explained exceeds 0.01.
cells_to_include = ve[ve['Full']>0.01].identifier.values
# Order the included cells by full-model variance explained (ascending).
order = np.argsort(ve[ve.identifier.isin(cells_to_include)==True]['Full'])
cell_order = cells_to_include[order]
len(cells_to_include)
# Re-pivot the same results, now using the fractional change in variance
# explained when each regressor group is dropped from the full model.
model_output_type = 'fraction_change_from_full'
rsp = rs.pivot(index='identifier',columns='dropout',values=model_output_type).reset_index()
rsp
# Attach the full-model variance explained per cell.
# NOTE(review): 'varience' is misspelled; kept as-is because later cells
# reference this exact column name.
tmp = ve.rename(columns={'Full':'varience_explained_full_model'})
rsp = rsp.merge(tmp[['identifier','varience_explained_full_model']], on=['identifier'])
# Restrict to cells that passed the variance-explained threshold above.
rsp = rsp[rsp.identifier.isin(cells_to_include)==True]
# Merge in per-cell metadata (cre line, session type, equipment name).
rspm = rsp.merge(rs[['identifier','cre_line','session_type','equipment_name']].drop_duplicates(),left_on='identifier',right_on='identifier',how='inner')
rspm
def map_session_types(session_type):
    """Return the single-character session number embedded in a session_type string.

    Assumes the format 'OPHYS_<n>_...' where the digit sits at index 6.
    Returns an empty string when the input is shorter than 7 characters
    (slicing, unlike indexing, never raises here).
    """
    return session_type[6:7]
# Derive a compact session number column from the session_type string.
# Fix: pass the function directly to .map — wrapping it in a lambda
# (lambda st: map_session_types(st)) was redundant.
rspm['session_id'] = rspm['session_type'].map(map_session_types)
rspm['session_id'].unique()
# save = False
# if save:
#     rspm.to_csv('/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/ophys_glm/fraction_change_var_explained_v_4_L2_fixed_lambda=1_2020.08.09.csv', index=False)
# Candidate clustering features: everything except metadata columns and the
# individual-image / visual dropouts.
cols_for_clustering = [col for col in rspm.columns if col not in ['identifier','cre_line','session_type','equipment_name', 'session_id']]
cols_for_clustering = [col for col in cols_for_clustering if col not in ['image0','image1','image2','image3',
                                                                         'image4','image5','image6','image7',
                                                                         'visual']]
cols_for_clustering
# Final, hard-coded feature list used for clustering.
# NOTE(review): this overwrites the list computed just above, making that
# computation effectively dead code — confirm which selection is intended.
cols_for_clustering = [
    'omissions',
    'all-images',
    'image_expectation',
    'change',
    'hits',
    'misses',
    'correct_rejects',
    'false_alarms',
    'post_lick_bouts',
    'post_licks',
    'pre_lick_bouts',
    'pre_licks',
    'rewards',
    'pupil',
    'running',
    'time',
    'model_bias',
    'model_omissions1',
    'model_task0',
    'model_timing1D',
]
rspm[cols_for_clustering]
# The cells x features matrix that all clustering below operates on.
feature_matrix = rspm[cols_for_clustering]
# Heatmap of the clustering feature matrix in dataframe row order.
fig, ax = plt.subplots(figsize=(6,10))
ax = sns.heatmap(feature_matrix, vmin=-0.5, vmax=0.5, center=0, cmap='RdBu_r', ax=ax, cbar_kws={'label':model_output_type})
ax.set_ylabel('cells')
ax.set_title('GLM feature matrix')
# Same matrix with rows sorted by full-model variance explained.
feature_matrix = rspm.sort_values('varience_explained_full_model').reset_index()[cols_for_clustering]
fig, ax = plt.subplots(figsize=(6,10))
ax = sns.heatmap(feature_matrix, vmin=-0.5, vmax=0.5, cmap='RdBu_r', center=0, ax=ax, cbar_kws={'label':model_output_type})
ax.set_ylabel('cells')
ax.set_title('GLM feature matrix')
# Rows sorted by the omissions dropout score.
# Fix: added center=0 so the diverging colormap is anchored at zero,
# consistent with the two heatmaps above.
feature_matrix = rspm.sort_values('omissions').reset_index()[cols_for_clustering]
fig, ax = plt.subplots(figsize=(6,10))
ax = sns.heatmap(feature_matrix, vmin=-0.5, vmax=0.5, cmap='RdBu_r', center=0, ax=ax, cbar_kws={'label':model_output_type})
ax.set_ylabel('cells')
ax.set_title('GLM feature matrix')
# Pairwise scatter plots of selected dropout scores, colored by cre line.
for x_col, y_col in [('omissions', 'all-images'), ('omissions', 'pupil'), ('pupil', 'running')]:
    fig, ax = plt.subplots(figsize=(4, 4))
    ax = sns.scatterplot(data=rspm, x=x_col, y=y_col, hue='cre_line', ax=ax)

# One pointplot per clustering feature: score vs. session number, split by cre line.
colors = sns.color_palette()
colors = [colors[0], colors[2], colors[3]]
cre_lines = np.sort(rspm.cre_line.unique())
for metric in cols_for_clustering:
    fig, ax = plt.subplots(figsize=(6, 4))
    sns.pointplot(data=rspm, x='session_id', y=metric, hue='cre_line', hue_order=cre_lines, palette=colors, ax=ax)
# Full PCA on the clustering features: n_components == n_features, so this
# is a complete decomposition (a rotation, no dimensionality reduction).
n_features = len(cols_for_clustering)
n_components = len(cols_for_clustering)
pca = PCA(n_components=n_components)
pca_result = pca.fit_transform(rspm[cols_for_clustering].values)
# Keep the first three PC scores on the dataframe for plotting below.
rspm['pc1'] = pca_result[:,0]
rspm['pc2'] = pca_result[:,1]
rspm['pc3'] = pca_result[:,2]
print('Explained variation per principal component: {}'.format(pca.explained_variance_ratio_))
# Scree plot: fraction of variance explained per component.
fig,ax=plt.subplots()
ax.plot(
    np.arange(n_components),
    pca.explained_variance_ratio_,
    'o-k'
)
ax.set_xlabel('PC number')
ax.set_ylabel('variance explained')
# NOTE(review): the '8 PCs / >95%' figure in this title is hard-coded —
# verify it against the searchsorted results below.
ax.set_title('first 8 PCs explain >95% of the variance')
# Cumulative variance explained, and the PC count needed to reach 90% / 95%.
np.cumsum(pca.explained_variance_ratio_)
np.searchsorted(np.cumsum(pca.explained_variance_ratio_), .90)
np.searchsorted(np.cumsum(pca.explained_variance_ratio_), .95)
# Heatmap of the PCA loading matrix (components x features).
fig, ax = plt.subplots(figsize=(6,6))
ax = sns.heatmap(pca.components_, vmin=-1, vmax=1, cmap='RdBu_r', ax=ax, square=True,
                 robust=True, cbar_kws={"drawedges": False, "shrink": 0.7, "label": 'weight'})
ax.set_ylabel('principal components')
ax.set_xlabel('features')
# ax.set_title('principal axes in feature space \n(directions of maximum variance in the data)')
ax.set_ylim(0, n_components)
ax.set_xticklabels(cols_for_clustering, rotation=90);
pca.components_.shape
# Line plot of the loading vectors for the first 8 PCs.
fig,ax=plt.subplots(figsize=(12,4))
N_PCs = 8
for PC in range(N_PCs):
    ax.plot(pca.components_[PC,:])
ax.legend(np.arange(N_PCs), title='PC', bbox_to_anchor=(1,1))
# NOTE(review): 32 looks like a stale feature count — there are
# len(cols_for_clustering)=20 features here; confirm the intended band.
ax.axhspan(1/np.sqrt(32), -1/np.sqrt(32), zorder=-np.inf, alpha=0.25)
ax.set_xticks(np.arange(len(cols_for_clustering)))
ax.set_xticklabels(cols_for_clustering, rotation=45, ha='right')
ax.set_ylabel('weight')
ax.set_ylim(-1.1,1.1)
fig.tight_layout()
# Loading vectors for the remaining PCs (indices 8 through 19).
fig, ax = plt.subplots(figsize=(12,4))
for PC in range(8, 20):
    ax.plot(pca.components_[PC,:])
# Fix: legend labels now match the plotted PC indices 8..19; the original
# used np.arange(10,21), which was off by two and had one extra entry.
ax.legend(np.arange(8, 20), title='PC', bbox_to_anchor=(1,1))
# NOTE(review): 32 looks like a stale feature count — there are 20 features
# here; confirm the intended band.
ax.axhspan(1/np.sqrt(32), -1/np.sqrt(32), zorder=-np.inf, alpha=0.25)
ax.set_xticks(np.arange(len(cols_for_clustering)))
ax.set_xticklabels(cols_for_clustering, rotation=45, ha='right')
# Fix: replaced the placeholder y-axis label '??' with 'weight', matching
# the analogous PC 0-7 figure.
ax.set_ylabel('weight')
ax.set_ylim(-1.1,1.1)
fig.tight_layout()
# Feature-feature covariance as estimated by the fitted PCA model.
covariance = pca.get_covariance()
cbar_opts = {"drawedges": False, "shrink": 0.7, "label": 'covariance'}
fig, ax = plt.subplots(figsize=(6, 6))
ax = sns.heatmap(covariance, vmin=-0.2, vmax=0.2, cmap='RdBu_r', ax=ax, square=True,
                 robust=True, cbar_kws=cbar_opts)
ax.set_title('covariance matrix')
ax.set_ylim(0, n_features)
ax.set_xticklabels(cols_for_clustering, rotation=90);
ax.set_yticklabels(cols_for_clustering, rotation=0);
# PC scores sorted by the first component (evaluated for display only).
pca_result[np.argsort(pca_result[:,0])]
# Heatmap of PC scores (cells x components) in original row order.
fig, ax = plt.subplots(figsize=(6,10))
ax = sns.heatmap(pca_result, cmap='RdBu_r', ax=ax,
                 robust=True, cbar_kws={"drawedges": False, "shrink": 0.7, "label": 'activation'})
ax.set_ylabel('cells')
ax.set_xlabel('PC')
# NOTE(review): this title looks mislabeled — the plotted values are PC
# scores from fit_transform, not percent variance explained.
ax.set_title('% variance explained in PC space')
ax.set_ylim(0, pca_result.shape[0])
ax.set_xlim(0, pca_result.shape[1])
ax.set_xticks(np.arange(0, pca_result.shape[1]));
# ax.set_xticklabels(cols_for_clustering, rotation=90);
# Same heatmap with rows sorted by the PC-0 score.
fig, ax = plt.subplots(figsize=(6,10))
ax = sns.heatmap(pca_result[np.argsort(pca_result[:,0])], cmap='RdBu_r', ax=ax,
                 robust=True, cbar_kws={"drawedges": False, "shrink": 0.7, "label": 'activation'})
ax.set_ylabel('cells')
ax.set_xlabel('PC')
# NOTE(review): same apparent mislabeling as the previous title.
ax.set_title('% variance explained in PC space')
ax.set_ylim(0, pca_result.shape[0])
ax.set_xlim(0, pca_result.shape[1])
ax.set_xticks(np.arange(0, pca_result.shape[1]));
# ax.set_xticklabels(cols_for_clustering, rotation=90);
# pc1 and pc2 columns in rspm correspond to pca_results for those PCs - is this the 'score' per cell?
# Scatter the first three PC scores against each other, colored by cre line.
fig,ax = plt.subplots(1,2,figsize=(12,6))
ax[0] = sns.scatterplot(data=rspm, x="pc1", y="pc2", hue="cre_line",
                        palette=sns.color_palette("hls", 3), legend="full", alpha=0.3, ax=ax[0])
# ax[0].set_xlim(-5,10)
# ax[0].set_ylim(-5,10)
ax[1] = sns.scatterplot(data=rspm, x="pc2", y="pc3", hue="cre_line",
                        palette=sns.color_palette("hls", 3), legend="full", alpha=0.3, ax=ax[1])
# ax[1].set_xlim(-5,10)
# ax[1].set_ylim(-5,10)
# All PC scores as a DataFrame indexed by cell identifier, with cre line attached.
pca_result_df = pd.DataFrame(pca_result, index=rspm.identifier)
pca_result_df['cre_line'] = rspm['cre_line'].values
# pc1 and pc2 columns in rspm correspond to pca_results for those PCs - is this the 'score' per cell?
# 0-based column indices into pca_result_df for the first four PCs.
# Fix: PC3/PC4 were 3/4, which skipped component index 2 entirely;
# PC<n> corresponds to column n-1, matching PC1=0 and PC2=1.
PC1 = 0
PC2 = 1
PC3 = 2
PC4 = 3
# Adjacent-PC scatter plots, colored by cre line.
fig, ax = plt.subplots(1, 3, figsize=(15,5))
ax = ax.ravel()
i = 0
ax[i] = sns.scatterplot(data=pca_result_df, x=PC1, y=PC2, hue="cre_line",
                        palette=sns.color_palette("hls", 3), legend="full", alpha=0.3, ax=ax[i])
i += 1
ax[i] = sns.scatterplot(data=pca_result_df, x=PC2, y=PC3, hue="cre_line",
                        palette=sns.color_palette("hls", 3), legend="full", alpha=0.3, ax=ax[i])
i += 1
ax[i] = sns.scatterplot(data=pca_result_df, x=PC3, y=PC4, hue="cre_line",
                        palette=sns.color_palette("hls", 3), legend="full", alpha=0.3, ax=ax[i])
fig.tight_layout()
# Interactive 3D scatter of the first three PC scores (plotly), restricted
# to cells with all three scores inside (-100, 100).
query_string = '''pc1>-100 and pc1<100 and pc2>-100 and pc2<100 and pc3>-100 and pc3<100'''
fig = px.scatter_3d(
    rspm.query(query_string),
    x='pc1',
    y='pc2',
    z='pc3',
    color='cre_line',
)
# Small, semi-transparent markers so dense regions stay readable.
fig.update_traces(
    marker=dict(
        size=3,
        opacity=0.25
    )
)
fig.update_layout(
    margin=dict(l=30, r=30, t=10, b=10),
    width=1200,
    height=1000,
)
# fig.write_html("/home/dougo/code/dougollerenshaw.github.io/figures_to_share/2020.08.09_PCA_on_GLM.html")
# fig.show()
# Commented-out exploratory grid of all pairwise PC scatter plots.
# # pc1 and pc2 columns in rspm correspond to pca_results for those PCs - is this the 'score' per cell?
# fig,ax = plt.subplots(n_components, n_components, figsize=(20,20))
# ax = ax.ravel()
# i = 0
# for PC1 in range(n_components):
#     for PC2 in range(n_components):
#         ax[i] = sns.scatterplot(data=pca_result_df, x=PC1, y=PC2, hue="cre_line",
#                                 palette=sns.color_palette("hls", 3), legend="full", alpha=0.3, ax=ax[i])
#         ax[i].set_xlim(-100,100)
#         ax[i].set_ylim(-100,100)
# # ax[1] = sns.scatterplot(data=rspm, x="pc2", y="pc3", hue="cre_line",
# #                         palette=sns.color_palette("hls", 3), legend="full", alpha=0.3, ax=ax[1])
# # ax[1].set_xlim(-100,100)
# # ax[1].set_ylim(-100,100)
# Re-plot the raw feature matrix and a few pairwise feature scatters
# ahead of clustering (repeats of figures above).
feature_matrix = rspm[cols_for_clustering]
fig, ax = plt.subplots(figsize=(6,10))
ax = sns.heatmap(feature_matrix, vmin=-0.5, vmax=0.5, cmap='RdBu_r', ax=ax, cbar_kws={'label':model_output_type})
ax.set_ylabel('cells')
ax.set_title('GLM feature matrix')
fig, ax = plt.subplots(figsize=(4,4))
ax = sns.scatterplot(data=rspm, x='omissions', y='all-images', hue='cre_line', ax=ax)
# ax.set_xlim(-100, 100)
# ax.set_ylim(-100, 100)
fig, ax = plt.subplots(figsize=(4,4))
ax = sns.scatterplot(data=rspm, x='omissions', y='running', hue='cre_line', ax=ax)
# ax.set_xlim(-100, 100)
# ax.set_ylim(-100, 100)
fig, ax = plt.subplots(figsize=(4,4))
ax = sns.scatterplot(data=rspm, x='omissions', y='pupil', hue='cre_line', ax=ax)
# ax.set_xlim(-100, 100)
# ax.set_ylim(-100, 100)
# K-means (10 clusters) directly on the raw feature matrix; labels stored per cell.
# NOTE(review): no random_state is set, so cluster labels are not
# reproducible across runs — confirm whether that matters downstream.
kmeans = KMeans(n_clusters=10)
kmeans_result = kmeans.fit_predict(feature_matrix)
rspm['kmeans_result_features'] = kmeans_result
rspm['kmeans_result_features'].value_counts()
# PC-score heatmap (cells x components) ahead of clustering in PC space.
# Fix: replaced the placeholder '?' colorbar label and the empty title with
# descriptive text; the values plotted are the PC scores from fit_transform.
fig, ax = plt.subplots(figsize=(6,10))
ax = sns.heatmap(pca_result, cmap='RdBu_r', ax=ax,
                 robust=True, cbar_kws={"drawedges": False, "shrink": 0.7, "label": 'PC score'})
ax.set_ylabel('cells')
ax.set_xlabel('PC')
ax.set_title('PC scores per cell')
ax.set_ylim(0, pca_result.shape[0])
ax.set_xlim(0, pca_result.shape[1])
ax.set_xticks(np.arange(0, pca_result.shape[1]));
# ax.set_xticklabels(cols_for_clustering, rotation=90);
# K-means (10 clusters) on the PC scores rather than the raw features.
# NOTE(review): no random_state — labels are not reproducible across runs.
kmeans = KMeans(n_clusters=10)
kmeans_result = kmeans.fit_predict(pca_result)
rspm['kmeans_result'] = kmeans_result
rspm['kmeans_result'].value_counts()
"Then we applied consensus clustering to the PCs, by running K-means using the PCs 100 times until reaching a stable co-clustering association matrix, where each entry represents the probability of two units belonging to the same cluster." - Xiaoxuan paper
# Loading-matrix heatmap again (components x features).
fig, ax = plt.subplots(figsize=(6,6))
ax = sns.heatmap(pca.components_, vmin=-1, vmax=1, cmap='RdBu_r', ax=ax, square=True,
                 robust=True, cbar_kws={"drawedges": False, "shrink": 0.7, "label": 'contribution'})
ax.set_ylabel('principal components')
ax.set_xlabel('GLM features')
ax.set_title('principal axes in feature space \n(directions of maximum variance in the data)')
ax.set_ylim(0, n_components)
ax.set_xticklabels(cols_for_clustering, rotation=90);
kmeans_result
# K-means on the loading vectors themselves — this clusters the components,
# not the cells. NOTE(review): confirm that is intentional; also no
# random_state, so labels vary between runs.
kmeans = KMeans(n_clusters=10)
kmeans_result = kmeans.fit_predict(pca.components_)
kmeans_result
# import hierarchical clustering libraries
import scipy.cluster.hierarchy as sch
from sklearn.cluster import AgglomerativeClustering

# Agglomerative (Ward-linkage) clustering on the raw feature matrix.
# NOTE(review): `affinity` was renamed to `metric` in scikit-learn 1.2 and
# removed in 1.4 — update the keyword when upgrading sklearn.
hc = AgglomerativeClustering(n_clusters=4, affinity = 'euclidean', linkage = 'ward')
# save clusters for chart
y_hc = hc.fit_predict(rspm[cols_for_clustering])
rspm['hc'] = y_hc
rspm['hc'].value_counts()
# Same clustering, but on the PCA scores instead of the raw features.
hc = AgglomerativeClustering(n_clusters=4, affinity = 'euclidean', linkage = 'ward')
# save clusters for chart
# Fix: the original referenced the undefined name `pca_results` (NameError);
# the PCA score array defined above is `pca_result`.
y_hc = hc.fit_predict(pca_result)
rspm['hc_pca'] = y_hc
rspm['hc_pca'].value_counts()